import torch
import torch.nn as nn
from .SubLayers import TPGablation
from .TransformerLayers import MultiHeadAttention, PositionwiseFeedForward


class ConvExpandAttr(nn.Module):
    '''
    Expand the channel (last) axis of a channels-last 4-D tensor with a Conv2d.

    [batch, n_route, n_time, 1] -> [batch, n_route, n_time, n_attr]

    The last input axis is used as the Conv2d channel axis (size ``c_in``) and
    the convolution slides over the (axis 1, axis 2) plane, so a ``kernel_size``
    larger than 1 shrinks those two axes.
    '''

    # axis orders used to hop between channels-last and channels-first layouts
    _TO_CHANNELS_FIRST = (0, 3, 1, 2)
    _TO_CHANNELS_LAST = (0, 2, 3, 1)

    def __init__(self, c_in, c_out, kernel_size, bias):
        super().__init__()
        self.conv = nn.Conv2d(c_in, c_out, kernel_size, bias=bias)

    def forward(self, x):
        # [b, r, t, c_in] -> [b, c_in, r, t] -> conv -> [b, c_out, r', t'] -> [b, r', t', c_out]
        channels_first = x.permute(*self._TO_CHANNELS_FIRST)
        return self.conv(channels_first).permute(*self._TO_CHANNELS_LAST)


class SpatioEnc(nn.Module):
    """Add a learnable per-route embedding to ``x``, optionally followed by LayerNorm.

    Args:
        n_route: number of rows of the embedding table; must match axis 1 of ``x``.
        n_attr: embedding width; must match the last axis of ``x``.
        normal: if True, apply ``LayerNorm(n_attr)`` after the addition.

    ``x`` is expected to be 4-D ``(batch, n_route, T, n_attr)`` (axis 2 is
    presumably time — confirm with callers); the output has the same shape.
    """

    def __init__(self, n_route, n_attr=33, normal=True):
        super().__init__()

        # learnable [n_route, n_attr] table, Xavier-uniform initialised
        self.enc = nn.Parameter(torch.empty(n_route, n_attr))
        nn.init.xavier_uniform_(self.enc.data)
        self.no = normal
        self.norm = nn.LayerNorm(n_attr, eps=1e-6)

    def forward(self, x):
        # Put routes on axis 2 so the [n_route, n_attr] table broadcasts over
        # (batch, T), add it, optionally normalise, then restore the layout.
        routes_second_last = x.permute(0, 2, 1, 3) + self.enc
        normed = self.norm(routes_second_last) if self.no else routes_second_last
        return normed.permute(0, 2, 1, 3)


class TempoEnc(nn.Module):
    """Add a learnable temporal embedding along axis -2 of ``x``.

    Args:
        n_time: number of embedding rows; every looked-up index must lie in
            ``[0, n_time)``.
        n_attr: embedding width; must match the last axis of ``x``.
        normal: if True, apply ``LayerNorm(n_attr)`` after the addition.
    """

    def __init__(
        self,
        n_time,
        n_attr,
        normal=True
    ):
        super().__init__()

        self.time = n_time
        self.enc = nn.Embedding(n_time, n_attr)
        self.no = normal
        self.norm = nn.LayerNorm(n_attr, eps=1e-6)

    def forward(self, x, start=0, t_left=None):
        """Add embedding rows to ``x`` (broadcast over leading axes).

        Args:
            x: tensor whose last two axes are ``(length, n_attr)``.
            start: first embedding index when ``t_left`` is None; rows
                ``start .. start + length - 1`` are used.
            t_left: optional explicit sequence of ``length`` integer indices
                (list, array or tensor); overrides ``start``.

        Returns:
            ``x + embedding`` (optionally LayerNorm-ed), same shape as ``x``.
        """
        # Build indices on the embedding's own device rather than hard-coding
        # ``.cuda()``, so the module works on CPU and on non-default GPUs.
        device = self.enc.weight.device
        if t_left is None:
            length = x.shape[-2]
            idx = torch.arange(start, start + length, device=device)
        else:
            # as_tensor converts lists/arrays/tensors directly to int64
            # (no float32 round-trip, no uninitialised tensor for a scalar).
            idx = torch.as_tensor(t_left, dtype=torch.long, device=device)
        x = x + self.enc(idx)
        if self.no:
            x = self.norm(x)
        return x


class MLP(nn.Module):
    """Three-layer perceptron applied along axis 2 of a 4-D input.

    Hidden widths are ``d_in // 2`` and ``d_in // 4`` with in-place ReLU in
    between; the last layer outputs ``d_out`` features.

    Input ``[batch, n_route, d_in, n_attr]`` (axis 2 presumably ``n_his``)
    becomes ``[batch, n_route, d_out, n_attr]``.
    """

    def __init__(self, d_in, d_out=1):
        super().__init__()
        half, quarter = d_in // 2, d_in // 4
        # Sequential indices (0, 2, 4 hold the Linear layers) define the
        # state_dict keys, so the layout must stay as is.
        self.linear = nn.Sequential(
            nn.Linear(d_in, half),
            nn.ReLU(inplace=True),
            nn.Linear(half, quarter),
            nn.ReLU(inplace=True),
            nn.Linear(quarter, d_out),
        )

    def forward(self, x):
        # swap the last two axes so the Linear stack acts on axis 2, then swap back
        return self.linear(x.permute(0, 1, 3, 2)).permute(0, 1, 3, 2)


class EncoderLayer(nn.Module):
    """Encoder layer: temporal self-attention -> FFN -> graph convolution -> FFN.

    ``self.stgc`` is chosen from ``opt`` by the first of ``opt.ST``,
    ``opt.STO`` or ``opt.STdrop`` whose ``'use'`` entry is truthy.

    Args:
        opt: configuration object. Always reads ``n_route``, ``n_his``,
            ``n_attr``, ``n_hid``, ``dis_mat``, ``drop_prob``, ``attn`` and
            ``ST``; ``STO`` / ``STdrop`` are read only when the earlier flags
            are off.
        spa_mask: accepted for signature compatibility; not used.
        tem_mask: mask forwarded to the temporal self-attention.
    """

    def __init__(
        self,
        opt,
        spa_mask, tem_mask
    ):
        super().__init__()

        n_route, n_his, n_attr, n_hid = opt.n_route, opt.n_his, opt.n_attr, opt.n_hid
        dis_mat = opt.dis_mat

        self.tem_attn = MultiHeadAttention(
            opt.attn['head'], n_attr, opt.attn['d_k'], opt.attn['d_v'], opt.attn['drop_prob'])
        self.tem_mask = tem_mask

        # NOTE(review): STAttnGraphConv, STAttnGraphConv_verO and
        # STAttnGraphConv_drop are not among this module's imports, so the
        # selected branch raises NameError unless those names are brought into
        # scope some other way - confirm where they live and import them.
        if opt.ST['use']:
            cfg = opt.ST
            n_head, d_q, d_k, d_c, kt, normal = (
                cfg['n_head'], cfg['d_q'], cfg['d_k'], cfg['d_c'], cfg['kt'], cfg['normal'])
            self.stgc = STAttnGraphConv(
                n_route, n_his, n_attr, n_attr, dis_mat, n_head, d_q, d_k, d_c, kt, normal)
        elif opt.STO['use']:
            cfg = opt.STO
            n_head, d_q, d_k, d_c, kt, normal = (
                cfg['n_head'], cfg['d_q'], cfg['d_k'], cfg['d_c'], cfg['kt'], cfg['normal'])
            self.stgc = STAttnGraphConv_verO(
                n_route, n_his, n_attr, n_attr, dis_mat, n_head, d_q, d_k, d_c, kt, normal)
        elif opt.STdrop['use']:
            kt, droprate, temperature = opt.STdrop['kt'], opt.drop_prob, opt.STdrop['temperature']
            self.stgc = STAttnGraphConv_drop(
                n_attr, n_attr, n_route, dis_mat, kt, droprate, temperature)
        else:
            # No graph convolution configured. Construction stays permissive,
            # but forward() reports the misconfiguration explicitly instead of
            # failing on a missing attribute.
            self.stgc = None

        self.pos_ff1 = PositionwiseFeedForward(n_attr, n_hid, opt.drop_prob)
        self.pos_ff2 = PositionwiseFeedForward(n_attr, n_hid, opt.drop_prob)

    def forward(self, x):
        """Apply attention, FFN, graph convolution and FFN to ``x`` in sequence.

        Raises:
            RuntimeError: if no graph convolution was configured in ``opt``.
        """
        if self.stgc is None:
            raise RuntimeError(
                "EncoderLayer has no graph convolution: enable one of "
                "opt.ST, opt.STO or opt.STdrop.")
        x = self.tem_attn(x, x, x, self.tem_mask)
        x = self.pos_ff1(x)
        x = self.stgc(x)
        x = self.pos_ff2(x)
        return x


class EncoderLayer_stamp(nn.Module):
    """Encoder layer whose graph convolution also receives a time stamp.

    Pipeline: temporal self-attention -> FFN -> ``self.stgc(x, stamp)`` -> FFN.
    ``self.stgc`` is obtained from ``TPGablation().get(opt.TPG, ...)``.

    Args:
        opt: configuration object; reads ``n_route``, ``n_his``, ``n_attr``,
            ``n_hid``, ``dis_mat``, ``drop_prob``, ``attn``, ``STstamp``
            (``'use'``, ``'kt'``, ``'temperature'``), ``TPG`` and ``n_c``.
        spa_mask: accepted for signature compatibility; not used.
        tem_mask: mask forwarded to the temporal self-attention.

    Raises:
        ValueError: if ``opt.STstamp['use']`` is falsy.
    """

    def __init__(
        self,
        opt,
        spa_mask, tem_mask
    ):
        super().__init__()

        # Explicit check instead of ``assert``: asserts are stripped under
        # ``python -O``, which would let a misconfigured layer be built.
        if not opt.STstamp['use']:
            raise ValueError("encoder_stamp requires time stamp as input.")

        n_route, n_his, n_attr, n_hid = opt.n_route, opt.n_his, opt.n_attr, opt.n_hid
        dis_mat = opt.dis_mat

        self.tem_attn = MultiHeadAttention(
            opt.attn['head'], n_attr, opt.attn['d_k'], opt.attn['d_v'], opt.attn['drop_prob'])
        self.tem_mask = tem_mask

        module_switch = TPGablation()
        kt, droprate, temperature = opt.STstamp['kt'], opt.drop_prob, opt.STstamp['temperature']
        self.stgc = module_switch.get(
            opt.TPG, n_attr, n_attr, n_route, n_his, dis_mat, kt, opt.n_c, droprate, temperature)

        self.pos_ff1 = PositionwiseFeedForward(n_attr, n_hid, opt.drop_prob)
        self.pos_ff2 = PositionwiseFeedForward(n_attr, n_hid, opt.drop_prob)

    def forward(self, x, stamp):
        """Run the layer on ``x``; ``stamp`` is passed unchanged to ``self.stgc``.

        The expected format of ``stamp`` is defined by the module returned from
        ``TPGablation.get`` (not visible here) - confirm with that module.
        """
        x = self.tem_attn(x, x, x, self.tem_mask)
        x = self.pos_ff1(x)
        x = self.stgc(x, stamp)
        x = self.pos_ff2(x)
        return x


class DecoderLayer(nn.Module):
    """Decoder layer: cross attention onto the encoder output, then an FFN.

    ``self.slf_attn`` / ``self.slf_mask`` are built and stored but are not used
    by ``forward``. The attention module is kept because, if it owns parameters
    (likely), removing it would change the ``state_dict`` keys.

    Args:
        opt: configuration object; reads ``n_attr``, ``n_hid``, ``drop_prob``
            and ``attn`` (``'head'``, ``'d_k'``, ``'d_v'``, ``'drop_prob'``).
        slf_mask: stored as ``self.slf_mask``; unused in ``forward``.
        mul_mask: mask forwarded to the cross attention.
    """

    def __init__(self, opt, slf_mask, mul_mask):
        super().__init__()

        attn_cfg = opt.attn
        attn_args = (
            attn_cfg['head'], opt.n_attr, attn_cfg['d_k'], attn_cfg['d_v'], attn_cfg['drop_prob'])

        # Build order (self-attn, cross-attn, FFN) is kept so parameter
        # initialisation draws from the RNG in the same sequence.
        self.slf_attn = MultiHeadAttention(*attn_args)
        self.slf_mask = slf_mask
        self.mul_attn = MultiHeadAttention(*attn_args)
        self.mul_mask = mul_mask
        self.pos_ff = PositionwiseFeedForward(opt.n_attr, opt.n_hid, opt.drop_prob)

    def forward(self, x, enc_output):
        """Attend from ``x`` to ``enc_output`` (as keys and values), then apply the FFN."""
        attended = self.mul_attn(x, enc_output, enc_output, self.mul_mask)
        return self.pos_ff(attended)
